{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "64468ad7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "notebook_dir = os.getcwd()\n",
    "project_root_path = os.path.dirname(notebook_dir)\n",
    "sys.path.insert(0, project_root_path)\n",
    "\n",
    "from src.preprocessing.RIVAL10 import preprocessing_rival10\n",
    "from src.utils import *\n",
    "from src.config import PROJECT_ROOT\n",
    "from src.training import run_epoch_x_to_c\n",
    "\n",
    "from src.utils import find_class_imbalance\n",
    "from src.config import RIVAL10_CONFIG\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c60c4e73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of concept columns, read from the RIVAL10 config. Used below to size\n",
    "# the per-concept loss list and the prediction matrix saved for each split.\n",
    "N_TRIMMED_CONCEPTS = RIVAL10_CONFIG['N_TRIMMED_CONCEPTS']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "62f39434-7c40-43b6-bde9-193b7f3b4287",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release cached MPS memory before the heavy cells run. The call is guarded\n",
    "# because torch.mps.empty_cache() raises when the MPS backend is unavailable,\n",
    "# while the device-selection cell also supports CUDA and CPU machines.\n",
    "if torch.backends.mps.is_available():\n",
    "    torch.mps.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7bff13a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 26384 unique images.\n",
      "Found 18 unique concepts.\n",
      "Generated one-hot training matrix with shape: (21098, 10)\n",
      "Found 21098 images.\n",
      "Processing in 330 batches of size 64 (for progress reporting)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 330/330 [00:53<00:00,  6.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished processing.\n",
      "Successfully transformed: 21098 images.\n",
      "Found 5286 images.\n",
      "Processing in 83 batches of size 64 (for progress reporting)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 83/83 [00:15<00:00,  5.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished processing.\n",
      "Successfully transformed: 5286 images.\n",
      "Dataset initialized with 21098 pre-sorted items.\n",
      "Dataset initialized with 5286 pre-sorted items.\n",
      "Split train dataset: 16878 training samples, 4220 validation samples\n"
     ]
    }
   ],
   "source": [
    "# Seed torch's global RNG before the datasets and loaders are built.\n",
    "torch.manual_seed(42)\n",
    "# Build the concept labels and the train/validation/test loaders. The recorded\n",
    "# run below reports 16878 training and 4220 validation samples.\n",
    "# NOTE(review): the effect of training=False and class_concepts=True is defined\n",
    "# in src.preprocessing.RIVAL10 - confirm there (e.g. whether loaders shuffle).\n",
    "concept_labels, train_loader, val_loader, test_loader = preprocessing_rival10(training=False, class_concepts=True, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18e2cf9e",
   "metadata": {},
   "source": [
    "**Select the device to run the model on (CUDA GPU, Apple MPS, or CPU).**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ca5893fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available()\n",
    "                    else \"mps\" if torch.backends.mps.is_available()\n",
    "                    else \"cpu\")\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16bc3f7e",
   "metadata": {},
   "source": [
    "### Loss\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2f33d217",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_weighted_loss = True  # Set to False for simple unweighted loss\n",
    "\n",
    "# Build one BCE-with-logits criterion per concept.\n",
    "if not use_weighted_loss:\n",
    "    attr_criterion = [nn.BCEWithLogitsLoss() for _ in range(N_TRIMMED_CONCEPTS)]\n",
    "else:\n",
    "    # One weight per concept, as returned by find_class_imbalance.\n",
    "    # NOTE(review): `weight` rescales the whole per-element loss; if the ratios\n",
    "    # are meant to up-weight positive labels only, `pos_weight` is the argument\n",
    "    # for that - confirm the intent against the training code.\n",
    "    concept_weights = find_class_imbalance(concept_labels)\n",
    "    attr_criterion = [\n",
    "        nn.BCEWithLogitsLoss(\n",
    "            weight=torch.tensor([ratio], device=device, dtype=torch.float)\n",
    "        )\n",
    "        for ratio in concept_weights\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1c567dd5-545b-4d38-ab3c-3a19a028a40b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_outputs_as_array(outputs, n_concepts):\n",
    "    # Initialize an empty list to collect batches\n",
    "    batch_results = []\n",
    "\n",
    "    for i in range(len(outputs)):\n",
    "        batch_size = outputs[i].shape[0]\n",
    "\n",
    "        # Create a batch matrix with N_CONCEPTS number of columns\n",
    "        batch_matrix = np.zeros((batch_size, n_concepts))\n",
    "\n",
    "        for instance_idx in range(batch_size):\n",
    "            # Extract, convert, and flatten data for the current concept\n",
    "            instance_data = outputs[i][instance_idx].detach().cpu().numpy().flatten()\n",
    "            batch_matrix[instance_idx, :] = instance_data\n",
    "\n",
    "        # Add this consistently shaped batch matrix to our collection\n",
    "        batch_results.append(batch_matrix)\n",
    "\n",
    "    return np.vstack(batch_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f7dda06",
   "metadata": {},
   "source": [
    "# Load trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "61fd0fe6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best model loaded.\n"
     ]
    }
   ],
   "source": [
    "# Path of the serialized model: <PROJECT_ROOT>/models/RIVAL10/best_model.pth\n",
    "best_model = os.path.join(PROJECT_ROOT, 'models', 'RIVAL10', 'best_model.pth')\n",
    "# SECURITY: weights_only=False makes torch.load unpickle arbitrary objects, which\n",
    "# can execute code stored in the file - only load checkpoints you trust.\n",
    "# map_location places the loaded tensors on the selected device.\n",
    "# NOTE(review): no model.eval() call here - presumably run_epoch_x_to_c switches\n",
    "# the model to eval mode when optimizer is None; confirm.\n",
    "model = torch.load(best_model, map_location=device, weights_only=False)\n",
    "print(\"Best model loaded.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0856ebd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def get_outputs(loader, split_name):\n",
    "    if loader:\n",
    "        with torch.no_grad():\n",
    "            shuffled_concept_labels = []\n",
    "            shuffled_img_labels = []\n",
    "\n",
    "            # Iterate through all batches with a progress bar\n",
    "            for batch in tqdm(loader, desc=f\"Processing {split_name} data\"):\n",
    "                _, concept_labels, image_labels, _ = batch\n",
    "                # Append batch labels to our list\n",
    "                shuffled_concept_labels.append(concept_labels)\n",
    "                shuffled_img_labels.append(image_labels)\n",
    "\n",
    "            # Concatenate all batches into a single tensor\n",
    "            shuffled_concept_labels = torch.cat(shuffled_concept_labels, dim=0)\n",
    "            shuffled_img_labels = torch.cat(shuffled_img_labels, dim=0)\n",
    "\n",
    "            test_loss, test_acc, outputs = run_epoch_x_to_c(\n",
    "                model, loader, attr_criterion, optimizer=None, n_concepts=N_TRIMMED_CONCEPTS, device=device,\n",
    "                return_outputs='sigmoid', verbose=True\n",
    "            )\n",
    "\n",
    "    # print(f\"Shuffled labels shape: {shuffled_img_labels.shape}\")\n",
    "    output_dir = os.path.join(PROJECT_ROOT, 'output', 'RIVAL10')\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "    np.save(os.path.join(output_dir, f'C_{split_name}.npy'), shuffled_concept_labels)\n",
    "    np.save(os.path.join(output_dir, f'Y_{split_name}.npy'), shuffled_img_labels)\n",
    "    print(f'Best Model Summary   | Loss: {test_loss:.4f} | Acc: {test_acc:.3f}')\n",
    "\n",
    "    output_array = get_outputs_as_array(outputs, N_TRIMMED_CONCEPTS)\n",
    "    print(f\"Final shape: {output_array.shape}\")\n",
    "\n",
    "    np.save(os.path.join(output_dir, f'C_hat_sigmoid_{split_name}.npy'), output_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f2682697",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing train data: 100%|██████████| 264/264 [01:37<00:00,  2.70it/s]\n",
      "                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best Model Summary   | Loss: 0.0317 | Acc: 99.739\n",
      "Final shape: (16878, 18)\n"
     ]
    }
   ],
   "source": [
    "get_outputs(train_loader, 'train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1aee2216",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing test data: 100%|██████████| 83/83 [00:16<00:00,  5.07it/s]\n",
      "                                                                                     \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best Model Summary   | Loss: 0.0381 | Acc: 99.709\n",
      "Final shape: (5286, 18)\n"
     ]
    }
   ],
   "source": [
    "get_outputs(test_loader, 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40672643-b9a6-4d6e-8769-8fc9b4dd43fd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d8ff41-8f71-417b-8f4d-7a113434b2ff",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
